{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import tensorflow as tf\n",
    "import tflearn\n",
    "from tflearn.layers.conv import conv_2d,max_pool_2d\n",
    "from tflearn.layers.core import input_data,dropout,fully_connected\n",
    "from tflearn.layers.estimator import regression\n",
    "import numpy as np\n",
    "import cv2\n",
    "from sklearn.utils import shuffle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training images: 1000 per gesture, in the order swing, palm, fist.\n",
    "# The label cell builds its one-hot vectors in this same order.\n",
    "loadedImages = []\n",
    "for path_template in ('Dataset/SwingImages/swing_{}.png',\n",
    "                      'Dataset/PalmImages/palm_{}.png',\n",
    "                      'Dataset/FistImages/fist_{}.png'):\n",
    "    for i in range(0, 1000):\n",
    "        path = path_template.format(i)\n",
    "        image = cv2.imread(path)\n",
    "        # cv2.imread returns None instead of raising when a file cannot be read,\n",
    "        # so fail here with the offending path rather than inside cvtColor\n",
    "        if image is None:\n",
    "            raise FileNotFoundError('Could not read image: ' + path)\n",
    "        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n",
    "        # Add an explicit single-channel axis to match the network input shape\n",
    "        loadedImages.append(gray_image.reshape(89, 100, 1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the one-hot training labels: 1000 rows per class, in the order\n",
    "# [1, 0, 0], then [0, 1, 0], then [0, 0, 1]\n",
    "outputVectors = []\n",
    "for one_hot in ([1, 0, 0], [0, 1, 0], [0, 0, 1]):\n",
    "    # list(...) gives every row its own list object, as separate appends would\n",
    "    outputVectors.extend(list(one_hot) for _ in range(1000))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the test images: 100 per gesture, in the order swing, palm, fist\n",
    "testImages = []\n",
    "for path_template in ('Dataset/SwingTest/swing_{}.png',\n",
    "                      'Dataset/PalmTest/palm_{}.png',\n",
    "                      'Dataset/FistTest/fist_{}.png'):\n",
    "    for i in range(0, 100):\n",
    "        path = path_template.format(i)\n",
    "        image = cv2.imread(path)\n",
    "        # cv2.imread returns None instead of raising when a file cannot be read,\n",
    "        # so fail here with the offending path rather than inside cvtColor\n",
    "        if image is None:\n",
    "            raise FileNotFoundError('Could not read image: ' + path)\n",
    "        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n",
    "        testImages.append(gray_image.reshape(89, 100, 1))\n",
    "\n",
    "# One-hot test labels: 100 rows per class, in the same order as the images above\n",
    "testLabels = []\n",
    "for one_hot in ([1, 0, 0], [0, 1, 0], [0, 0, 1]):\n",
    "    # list(...) gives every row its own list object\n",
    "    testLabels.extend(list(one_hot) for _ in range(100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the CNN model on a fresh default graph\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Input: batches of 89x100 single-channel images\n",
    "convnet = input_data(shape=[None, 89, 100, 1], name='input')\n",
    "\n",
    "# Seven conv + max-pool stages; each entry is the filter count of one stage\n",
    "for n_filters in (32, 64, 128, 256, 256, 128, 64):\n",
    "    convnet = conv_2d(convnet, n_filters, 2, activation='relu')\n",
    "    convnet = max_pool_2d(convnet, 2)\n",
    "\n",
    "# Classifier head: dense layer, dropout, then a 3-way softmax output\n",
    "convnet = fully_connected(convnet, 1000, activation='relu')\n",
    "convnet = dropout(convnet, 0.75)\n",
    "convnet = fully_connected(convnet, 3, activation='softmax')\n",
    "\n",
    "# Training op: Adam on categorical cross-entropy\n",
    "convnet = regression(convnet, optimizer='adam', learning_rate=0.001,\n",
    "                     loss='categorical_crossentropy', name='regression')\n",
    "\n",
    "model = tflearn.DNN(convnet, tensorboard_verbose=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training Step: 120  | total loss: \u001b[1m\u001b[32m0.06086\u001b[0m\u001b[0m | time: 19.870s\n",
      "| Adam | epoch: 003 | loss: 0.06086 - acc: 0.9971 -- iter: 2944/3000\n",
      "Training Step: 121  | total loss: \u001b[1m\u001b[32m0.05477\u001b[0m\u001b[0m | time: 21.241s\n",
      "| Adam | epoch: 003 | loss: 0.05477 - acc: 0.9974 | val_loss: 0.27714 - val_acc: 0.9267 -- iter: 3000/3000\n",
      "--\n",
      "INFO:tensorflow:C:\\Users\\t-spsah\\Documents\\Hand-Gesture-Recognition-Using-Background-Elllimination-and-Convolution-Neural-Network\\TrainedModel\\GestureRecogModel.tfl is not in all_model_checkpoint_paths. Manually adding it.\n"
     ]
    }
   ],
   "source": [
    "# Shuffle the training data; images and labels are shuffled together so pairs stay aligned\n",
    "loadedImages, outputVectors = shuffle(loadedImages, outputVectors, random_state=0)\n",
    "\n",
    "# Create the output directory before training so a missing folder cannot\n",
    "# make the save fail after the whole training run has finished\n",
    "os.makedirs('TrainedModel', exist_ok=True)\n",
    "\n",
    "# Train the model, using the test images and labels as the validation set\n",
    "model.fit(loadedImages, outputVectors, n_epoch=50,\n",
    "           validation_set = (testImages, testLabels),\n",
    "           snapshot_step=100, show_metric=True, run_id='convnet_coursera')\n",
    "\n",
    "model.save(\"TrainedModel/GestureRecogModel.tfl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
